# -*-Python-*-

# The vanilla transformer model but with branched attention as described in the
# "Weighted Transformer" paper (https://arxiv.org/pdf/1711.02132.pdf)
include 'models/vanilla_transformer.gin'

import transformer_modifications.transformer_modifications.optimizers
import mesh_tensorflow.optimize
import mesh_tensorflow.transformer.learning_rate_schedules
import mesh_tensorflow.transformer.transformer_layers


# Use AdafactorWithMultiLRSchedule as the optimizer handed to utils.run. Its
# `variable_search` and `alt_lr_schedules` parameters are bound further down
# in this file.
utils.run.optimizer = @optimize.AdafactorWithMultiLRSchedule

# Layer stacks for the branched-attention ("Weighted Transformer") model.
#
# Removed the DenseReluDense layers, because BranchedSelfAttention applies the
# DenseReluDense layer before summing over the heads.
# NOTE(review): "removed" is presumably relative to the layer lists set by the
# included models/vanilla_transformer.gin -- confirm against that file.
#
# Encoder: each block consists of a single branched self-attention layer.
encoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.BranchedSelfAttention,
]
# Decoder: branched self-attention followed by branched encoder-decoder
# attention.
decoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.BranchedSelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.BranchedEncDecAttention,
]

# Bind the optimizer's `variable_search` parameter to one pattern, "branched".
# NOTE(review): presumably this selects (by variable name) the variables that
# receive the alternate learning-rate schedule bound below -- confirm against
# the AdafactorWithMultiLRSchedule implementation.
AdafactorWithMultiLRSchedule.variable_search = ["branched"]


# Alternate learning-rate schedules: a list holding one inner list of two
# schedule functions (slanted_triangular and learning_rate_cut_off).
# NOTE(review): presumably one inner list per `variable_search` entry, with the
# functions in an inner list combined -- confirm against the optimizer.
# NOTE(review): the `at5_optimizers` prefix does not literally match the module
# path imported at the top of this file
# (transformer_modifications.transformer_modifications.optimizers) -- confirm
# the configurable is registered under this name.
AdafactorWithMultiLRSchedule.alt_lr_schedules = [[
    @learning_rate_schedules.slanted_triangular,
     @at5_optimizers.learning_rate_cut_off,
]]

# last 10K steps: cut_off_step is exactly 10,000 steps before the train_steps
# value bound below (524288 - 514288 = 10000).
at5_optimizers.learning_rate_cut_off.cut_off_step = 514288

# Total number of training steps supplied to compute_lr_for_step; keep in sync
# with cut_off_step above if either value changes.
optimize.compute_lr_for_step.train_steps = 524288
